{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "653089c9",
   "metadata": {},
   "source": [
    "# Chapter 8: Text Generation with Recurrent Neural Networks (RNNs)\n",
    "\n",
    "\n",
    "This chapter covers\n",
    "\n",
    "* The idea behind RNNs and why they can handle sequential data \n",
    "* Character tokenization, word tokenization, and subword tokenization\n",
    "* How word embedding works\n",
    "* Building and training an RNN to generate text \n",
    "* Using temperature and top-K sampling to control the creativeness of text generation\n",
    "\n",
    "So far in this book, we have discussed how to generate shapes, numbers, and images. Starting from this chapter, we’ll focus mainly on text generation. Generating text is often considered the holy grail of generative AI for several compelling reasons. Human language is incredibly complex and nuanced. It involves not only understanding grammar and vocabulary but also context, tone, and cultural references. Successfully generating coherent and contextually appropriate text is a significant challenge that requires deep understanding and processing of language. \n",
    "\n",
    "As humans, we primarily communicate through language. AI that can generate human-like text can interact more naturally with users, making technology more accessible and user-friendly. Text generation has many applications, from automating customer service responses to creating entire articles, scripting for games and movies, aiding in creative writing, and even building personal assistants. The potential impact across industries is enormous.\n",
    "\n",
    "In this chapter, we’ll make our first attempt at building and training models to generate text. You’ll learn to tackle three main challenges in modeling text generation. First, text is sequential data, consisting of data points organized in a specific sequence, where each point is successively ordered to reflect the inherent order and interdependencies within the data. Predicting outcomes for sequences is challenging due to their sensitive ordering. Altering the sequence of elements changes their meaning. Secondly, text exhibits long-range dependencies: the meaning of a certain part of the text depends on elements that appeared much earlier in the text (e.g., 200 words ago). Understanding and modeling these long-range dependencies is essential for generating coherent text. Lastly, human language is ambiguous and context dependent. Training a model to understand nuances, sarcasm, idioms, and cultural references to generate contextually accurate text is challenging.\n",
    "\n",
    "You'll explore a specific neural network designed for handling sequential data, such as text or time series: the recurrent neural network (RNN). Traditional neural networks, such as feedforward neural networks or fully connected networks, treat each input independently. This means that the network processes each input separately, without considering any relationship or order between different inputs. In contrast, RNNs are specifically designed to handle sequential data. In an RNN, the output at a given time step depends not only on the current input but also on previous inputs. This allows RNNs to maintain a form of memory, capturing information from previous time steps to influence the processing of the current input. \n",
    "This sequential processing makes RNNs suitable for tasks where the order of the inputs matters, such as language modeling, where the goal is to predict the next word in a sentence based on previous words. We’ll focus on one variant of RNN, Long Short-Term Memory (LSTM) networks, which can recognize both short-term and long-term data patterns in sequential data like text. LSTM models use a hidden state to capture information in previous time steps. Therefore, a trained LSTM model can produce coherent text based on the context. \n",
    "\n",
    "The style of the generated text depends on the training data. Additionally, as we plan to train a model from scratch for text generation, text length is a crucial factor. It needs to be sufficiently extensive for the model to effectively learn and mimic a particular writing style, yet concise enough to avoid excessive computational demands during training. As a result, we’ll use the text from the novel Anna Karenina to train an LSTM model. Since neural networks like an LSTM cannot accept text as input directly, you’ll learn to break down text into tokens (individual words in this chapter; but can be parts of words, as you’ll see in later chapters), a process known as tokenization. You’ll then create a dictionary to map each unique token into an integer (i.e., an index). Based on this dictionary, you’ll convert the text into a long sequence of integers, ready to be fed into a neural network. \n",
    "\n",
    "You’ll use sequences of indexes of a certain length as the input to train the LSTM model. You shift the sequence of inputs by one token to the right and use it as the output: you are effectively training the model to predict the next word in a sentence. This is the so-called sequence-to-sequence prediction problem in natural language processing (NLP) and you’ll see it again in later chapters. \n",
    "\n",
    "Once the LSTM is trained, you’ll use it to generate text one token at a time based on previous tokens in the sequence as follows: You feed a prompt (part of a sentence such as “Anna and the”) to the trained model. The model then predicts the most likely next token and appends the selected token to your prompt. The updated prompt serves again as the input and the model is used once more to predict the next token. The iterative process continues until the prompt reaches a certain length. This approach is similar to the mechanism employed by more advanced generative models like ChatGPT (though ChatGPT is not an LSTM). You’ll witness the trained LSTM model generating grammatically correct and coherent text, with a style matching that of the original novel. \n",
    "\n",
    "Finally, you also learn how to control the creativeness of the generated text by using temperature and top-K sampling. Temperature controls the randomness of the predictions of the trained model. A high temperature makes the generated text more creative while a low temperature makes the text more confident and predictable. Top-K sampling is a method where you select the next token from the top K most probable tokens, rather than selecting from the entire vocabulary. A small value of K leads to the selection of highly likely tokens in each step and this, in turn, makes the generated text less creative and more coherent.\n",
    "The primary goal of this chapter is not necessarily to generate the most coherent text possible, which, as mentioned earlier, presents substantial challenges. Instead, our objective is to demonstrate the limitations of RNNs, thereby setting the stage for the introduction of Transformers in subsequent chapters. More importantly, this chapter establishes the basic principles of text generation, including tokenization, word embedding, sequence prediction, temperature settings, and top-K sampling. Consequently, in the later chapters, you will have a solid understanding of the fundamentals of NLP. This foundation will allow us to concentrate on other, more advanced, aspects of NLP, such as how the attention mechanism functions and the architecture of Transformers."
   ]
  },
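  {
   "cell_type": "markdown",
   "id": "7e1a0c01",
   "metadata": {},
   "source": [
    "To make the idea of a hidden state concrete, the sketch below runs a toy recurrent step in plain Python. It is an illustration only, not the model trained in this chapter: the scalar weights `w_x`, `w_h`, and `b` and the helper names `rnn_step` and `run_rnn` are assumptions chosen for the example, and a real LSTM adds gates and a cell state on top of this basic recurrence.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def rnn_step(x_t, h_prev, w_x=0.8, w_h=0.5, b=0.0):\n",
    "    # One step of a toy scalar RNN: the new hidden state mixes the\n",
    "    # current input with the previous hidden state.\n",
    "    return math.tanh(w_x * x_t + w_h * h_prev + b)\n",
    "\n",
    "def run_rnn(inputs):\n",
    "    h = 0.0  # initial hidden state: no memory yet\n",
    "    for x_t in inputs:\n",
    "        h = rnn_step(x_t, h)\n",
    "    return h\n",
    "\n",
    "# Same elements in a different order\n",
    "h_a = run_rnn([1.0, 0.0, 0.0])\n",
    "h_b = run_rnn([0.0, 0.0, 1.0])\n",
    "print(h_a, h_b)\n",
    "```\n",
    "\n",
    "Both inputs contain the same values, yet the two final hidden states differ, because each step folds the current input into the state left by the earlier steps."
   ]
  },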
  {
   "cell_type": "markdown",
   "id": "30d1c19a",
   "metadata": {},
   "source": [
    "# 1.    Introduction to recurrent neural networks (RNNs)\n",
    "# 2.\tFundamentals of Natural Language Processing (NLP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "984fd7a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['I', 't', ' ', 'i', 's', ' ', 'u', 'n', 'b', 'e', 'l', 'i', 'e', 'v', 'a', 'b', 'l', 'y', ' ', 'g', 'o', 'o', 'd', '!']\n"
     ]
    }
   ],
   "source": [
    "# character tokenization\n",
    "text=\"It is unbelievably good!\"\n",
    "tokens=list(text)\n",
    "print(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b3f2efd1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['H', 'i', ',', ' ', 't', 'h', 'e', 'r', 'e', '!']\n"
     ]
    }
   ],
   "source": [
    "# exercise 8.1\n",
    "text=\"Hi, there!\"\n",
    "tokens=list(text)\n",
    "print(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "80157dc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['It', 'is', 'unbelievably', 'good', '!']\n"
     ]
    }
   ],
   "source": [
    "# word tokenization\n",
    "text=\"It is unbelievably good!\"\n",
    "text=text.replace(\"!\",\" !\")\n",
    "tokens=text.split(\" \")\n",
    "print(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2b2ec351",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Hi', ',', 'there', '!']\n"
     ]
    }
   ],
   "source": [
    "# exercise 8.2\n",
    "text=\"Hi, there!\"\n",
    "for x in list(\",!\"):\n",
    "    text=text.replace(f\"{x}\",f\" {x}\")\n",
    "tokens=text.split(\" \")\n",
    "print(tokens)"
   ]
  },
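  {
   "cell_type": "markdown",
   "id": "7e1a0c02",
   "metadata": {},
   "source": [
    "Once text is tokenized and each token has an index, a word embedding turns every index into a dense vector that the network can learn from. The sketch below shows this with PyTorch's `nn.Embedding`, which is a trainable lookup table. The five-token vocabulary, the name `toy_vocab`, and the embedding size of 4 are assumptions made for illustration; the model later in this chapter uses the full vocabulary and 128-dimensional vectors.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "tokens = ['it', 'is', 'unbelievably', 'good', '!']\n",
    "# Hypothetical toy vocabulary: one index per unique token\n",
    "toy_vocab = {w: i for i, w in enumerate(tokens)}\n",
    "\n",
    "# A lookup table with one trainable 4-dimensional row per token\n",
    "emb = nn.Embedding(num_embeddings=len(toy_vocab), embedding_dim=4)\n",
    "\n",
    "# A batch holding one sequence of five indexes: shape (1, 5)\n",
    "idx = torch.tensor([[toy_vocab[w] for w in tokens]])\n",
    "vectors = emb(idx)  # shape (1, 5, 4)\n",
    "print(idx.shape, vectors.shape)\n",
    "```\n",
    "\n",
    "The rows start out random and are updated during training along with the rest of the network, so tokens used in similar contexts can end up with similar vectors."
   ]
  },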
  {
   "cell_type": "markdown",
   "id": "c268aecc",
   "metadata": {},
   "source": [
    "# 3.\tPrepare data to train the LSTM model\n",
    "We'll use the text file of Anna Karenina in one of Carlos Lara's GitHub repositories. Go to the link https://github.com/LeanManager/NLP-PyTorch/tree/master/data to download the text file and save it as *anna.txt* in the folder /files/ on your computer. After that, open the file and delete everything after line 39888, which says \"END OF THIS PROJECT GUTENBERG EBOOK ANNA KARENINA.\" "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed84eadd",
   "metadata": {},
   "source": [
    "## 3.1\tDownload the clean up the text\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fbfb8bd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Chapter', '1\\n\\n\\nHappy', 'families', 'are', 'all', 'alike;', 'every', 'unhappy', 'family', 'is', 'unhappy', 'in', 'its', 'own\\nway.\\n\\nEverything', 'was', 'in', 'confusion', 'in', 'the', \"Oblonskys'\"]\n"
     ]
    }
   ],
   "source": [
    "with open(\"files/anna.txt\",\"r\") as f:\n",
    "    text=f.read()\n",
    "words=text.split(\" \") \n",
    "print(words[:20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c57af8c2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'f', 'd', '\"', 'h', 'b', '1', '3', '6', '8', 'o', 'i', '_', 'r', ',', ')', '9', 'a', 'j', '`', '5', '0', 'y', 'q', 'x', '2', 'g', 'z', 'p', \"'\", '?', 'm', 'v', ' ', '\\n', 'e', '-', ':', 'w', 'n', '.', ';', '4', 'k', '(', 'u', '!', '7', 'c', 's', 't', 'l'}\n"
     ]
    }
   ],
   "source": [
    "print(set(text.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4d463840",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_text=text.lower().replace(\"\\n\", \" \")\n",
    "clean_text=clean_text.replace(\"-\", \" \")\n",
    "for x in \",.:;?!$()/_&%*@'`\":\n",
    "    clean_text=clean_text.replace(f\"{x}\", f\" {x} \")\n",
    "clean_text=clean_text.replace('\"', ' \" ') \n",
    "text=clean_text.split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "47d983a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[',', '.', 'the', '\"', 'and', 'to', 'of', 'he', \"'\", 'a']\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter   \n",
    "word_counts = Counter(text)    \n",
    "\n",
    "# get unique words\n",
    "words=sorted(word_counts, key=word_counts.get,\n",
    "                      reverse=True) \n",
    "print(words[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ecca0e15",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the text contains 437098 words\n",
      "there are 12778 unique tokens\n",
      "{',': 0, '.': 1, 'the': 2, '\"': 3, 'and': 4, 'to': 5, 'of': 6, 'he': 7, \"'\": 8, 'a': 9}\n",
      "{0: ',', 1: '.', 2: 'the', 3: '\"', 4: 'and', 5: 'to', 6: 'of', 7: 'he', 8: \"'\", 9: 'a'}\n"
     ]
    }
   ],
   "source": [
    "text_length=len(text)\n",
    "num_unique_words=len(words)\n",
    "print(f\"the text contains {text_length} words\")\n",
    "print(f\"there are {num_unique_words} unique tokens\")  \n",
    "word_to_int={v:k for k,v in enumerate(words)} \n",
    "int_to_word={k:v for k,v in enumerate(words)}\n",
    "print({k:v for k,v in word_to_int.items() if k in words[:10]})\n",
    "print({k:v for k,v in int_to_word.items() if v in words[:10]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c1f21f00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['chapter', '1', 'happy', 'families', 'are', 'all', 'alike', ';', 'every', 'unhappy', 'family', 'is', 'unhappy', 'in', 'its', 'own', 'way', '.', 'everything', 'was']\n",
      "[208, 2755, 280, 2981, 83, 31, 2419, 35, 202, 685, 362, 38, 685, 10, 236, 147, 166, 1, 149, 12]\n"
     ]
    }
   ],
   "source": [
    "print(text[0:20])\n",
    "wordidx=[word_to_int[w] for w in text]  \n",
    "print([word_to_int[w] for w in text[0:20]])  "
   ]
  },
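  {
   "cell_type": "markdown",
   "id": "7e1a0c03",
   "metadata": {},
   "source": [
    "The two dictionaries work as an encoder and a decoder: one maps tokens to indexes, and the other maps indexes back to tokens. The sketch below shows the round trip on a handful of tokens. The helpers `encode` and `decode` and the `toy_` names are hypothetical and are not used elsewhere in this chapter. For brevity, the toy vocabulary is ordered by first appearance, whereas the chapter orders tokens by frequency.\n",
    "\n",
    "```python\n",
    "toy_tokens = ['happy', 'families', 'are', 'all', 'alike', ';', 'every',\n",
    "              'unhappy', 'family', 'is', 'unhappy']\n",
    "\n",
    "# Unique tokens in order of first appearance\n",
    "unique = list(dict.fromkeys(toy_tokens))\n",
    "toy_word_to_int = {w: i for i, w in enumerate(unique)}\n",
    "toy_int_to_word = {i: w for w, i in toy_word_to_int.items()}\n",
    "\n",
    "def encode(tokens):\n",
    "    # Map each token to its integer index\n",
    "    return [toy_word_to_int[w] for w in tokens]\n",
    "\n",
    "def decode(indexes):\n",
    "    # Map each integer index back to its token\n",
    "    return [toy_int_to_word[i] for i in indexes]\n",
    "\n",
    "ids = encode(toy_tokens)\n",
    "print(ids)\n",
    "print(decode(ids) == toy_tokens)\n",
    "```\n",
    "\n",
    "The repeated token `unhappy` receives the same index both times it appears, and decoding the indexes recovers the original token list."
   ]
  },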
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "38ac6e72-fab4-45e2-87eb-2e82caa63088",
   "metadata": {},
   "outputs": [],
   "source": [
    "# exercise 8.3\n",
    "print(word_to_int[\"anna\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8e0b2e3",
   "metadata": {},
   "source": [
    "## 3.2\tCreate batches of training data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c4823d8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "seq_len=100  \n",
    "xys=[]\n",
    "for n in range(0, len(wordidx)-seq_len-1):\n",
    "    x = wordidx[n:n+seq_len]\n",
    "    y = wordidx[n+1:n+seq_len+1]\n",
    "    xys.append((torch.tensor(x),(torch.tensor(y))))"
   ]
  },
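  {
   "cell_type": "markdown",
   "id": "7e1a0c04",
   "metadata": {},
   "source": [
    "With `seq_len=100` the (x, y) pairs are hard to inspect by eye. The sketch below applies the same sliding window to the first six indexes printed earlier, using a window of 4 so the shift is easy to see. The names `toy_idx`, `toy_seq_len`, and `pairs` are assumptions made for this illustration, and this version keeps every window that fits in the toy list.\n",
    "\n",
    "```python\n",
    "toy_idx = [208, 2755, 280, 2981, 83, 31]  # chapter, 1, happy, families, are, all\n",
    "toy_seq_len = 4\n",
    "\n",
    "pairs = []\n",
    "for n in range(len(toy_idx) - toy_seq_len):\n",
    "    x = toy_idx[n:n + toy_seq_len]          # input window\n",
    "    y = toy_idx[n + 1:n + toy_seq_len + 1]  # same window, one token later\n",
    "    pairs.append((x, y))\n",
    "\n",
    "for x, y in pairs:\n",
    "    print(x, '->', y)\n",
    "```\n",
    "\n",
    "At every position, the target `y[t]` is the token that follows `x[t]` in the text, so a single pair supplies one next-token prediction per position in the window."
   ]
  },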
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "08062972",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  39,   31,    2,  ...,  688,  142,    7],\n",
      "        [ 156, 5293,    0,  ...,   38,  330,    0],\n",
      "        [   3,   97,    0,  ...,    0, 1774,   34],\n",
      "        ...,\n",
      "        [  16,  156,    9,  ...,  113,    5,  533],\n",
      "        [   3,    4,   31,  ...,   98,    5,   98],\n",
      "        [ 289,   19,   23,  ...,    9,  828,  550]])\n",
      "tensor([[  31,    2, 2727,  ...,  142,    7,    0],\n",
      "        [5293,    0,   16,  ...,  330,    0,    3],\n",
      "        [  97,    0,    4,  ..., 1774,   34,    3],\n",
      "        ...,\n",
      "        [ 156,    9,  489,  ...,    5,  533,   27],\n",
      "        [   4,   31,   25,  ...,    5,   98,    1],\n",
      "        [  19,   23,    1,  ...,  828,  550,    1]])\n",
      "torch.Size([32, 100]) torch.Size([32, 100])\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "torch.manual_seed(42)\n",
    "batch_size=32\n",
    "loader = DataLoader(xys, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "x,y=next(iter(loader))\n",
    "print(x)\n",
    "print(y)\n",
    "print(x.shape,y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39e27a43",
   "metadata": {},
   "source": [
    "# 4. Build and Train the LSTM Model\n",
    "\n",
    "\n",
    "## 4.1\tBuild an LSTM model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "730313c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "device=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "class WordLSTM(nn.Module):\n",
    "    def __init__(self, input_size=128, n_embed=128,\n",
    "             n_layers=3, drop_prob=0.2):\n",
    "        super().__init__()\n",
    "        self.input_size=input_size\n",
    "        self.drop_prob = drop_prob\n",
    "        self.n_layers = n_layers\n",
    "        self.n_embed = n_embed\n",
    "        vocab_size=len(word_to_int)\n",
    "        self.embedding=nn.Embedding(vocab_size,n_embed)\n",
    "        self.lstm = nn.LSTM(input_size=self.input_size,\n",
    "            hidden_size=self.n_embed,\n",
    "            num_layers=self.n_layers,\n",
    "            dropout=self.drop_prob,batch_first=True)\n",
    "        self.fc = nn.Linear(input_size, vocab_size)\n",
    "    def forward(self, x, hc):\n",
    "        embed=self.embedding(x)\n",
    "        x, hc = self.lstm(embed, hc)\n",
    "        x = self.fc(x)\n",
    "        return x, hc        \n",
    "    def init_hidden(self, n_seqs):\n",
    "        weight = next(self.parameters()).data\n",
    "        return (weight.new(self.n_layers,\n",
    "                           n_seqs, self.n_embed).zero_(),\n",
    "                weight.new(self.n_layers,\n",
    "                           n_seqs, self.n_embed).zero_()) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0d429629",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WordLSTM(\n",
      "  (embedding): Embedding(12778, 128)\n",
      "  (lstm): LSTM(128, 128, num_layers=3, batch_first=True, dropout=0.2)\n",
      "  (fc): Linear(in_features=128, out_features=12778, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "model=WordLSTM().to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8c746b60",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=0.0001\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "loss_func = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bd8eed1",
   "metadata": {},
   "source": [
    "## 4.2\tTrain the LSTM model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "fe819235",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.train()\n",
    "\n",
    "for epoch in range(50):\n",
    "    tloss=0\n",
    "    sh,sc = model.init_hidden(batch_size)\n",
    "    for i, (x,y) in enumerate(loader):    \n",
    "        if x.shape[0]==batch_size:\n",
    "            inputs, targets = x.to(device), y.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            output, (sh,sc) = model(inputs, (sh,sc))\n",
    "            loss = loss_func(output.transpose(1,2),targets)\n",
    "            sh,sc=sh.detach(),sc.detach()\n",
    "            loss.backward()\n",
    "            nn.utils.clip_grad_norm_(model.parameters(), 5)\n",
    "            optimizer.step()\n",
    "            tloss+=loss.item()\n",
    "        if (i+1)%1000==0:\n",
    "            print(f\"at epoch {epoch} iteration {i+1}\\\n",
    "            average loss = {tloss/(i+1)}\")"
   ]
  },
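  {
   "cell_type": "markdown",
   "id": "7e1a0c05",
   "metadata": {},
   "source": [
    "The training loop calls `loss_func(output.transpose(1,2), targets)`. The transpose is needed because the model returns logits of shape (batch, sequence, vocabulary), while `nn.CrossEntropyLoss` expects the class dimension in position 1 when the input has more than two dimensions. The sketch below checks this on random tensors and compares it with the common alternative of flattening the batch and time dimensions. The sizes 2, 5, and 7 are arbitrary toy values, not the chapter's.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "batch, seq, vocab = 2, 5, 7               # toy sizes for illustration\n",
    "logits = torch.randn(batch, seq, vocab)   # same layout as the model output\n",
    "targets = torch.randint(0, vocab, (batch, seq))\n",
    "loss_func = nn.CrossEntropyLoss()\n",
    "\n",
    "# Option 1: move the class dimension to position 1 -> (batch, vocab, seq)\n",
    "loss_a = loss_func(logits.transpose(1, 2), targets)\n",
    "\n",
    "# Option 2: flatten batch and time -> (batch*seq, vocab) and (batch*seq,)\n",
    "loss_b = loss_func(logits.reshape(-1, vocab), targets.reshape(-1))\n",
    "\n",
    "print(torch.allclose(loss_a, loss_b))\n",
    "```\n",
    "\n",
    "With the default mean reduction, both forms average the per-token loss over all positions in the batch, so they agree up to floating-point error."
   ]
  },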
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d9effa9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "torch.save(model.state_dict(),\"files/wordLSTM.pth\")\n",
    "with open(\"files/word_to_int.p\",\"wb\") as fb:    \n",
    "    pickle.dump(word_to_int, fb) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "441286c7",
   "metadata": {},
   "source": [
    "# 5\tGenerate text with the trained LSTM model\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee861399",
   "metadata": {},
   "source": [
    "## 5.1\tGenerate text by predicting the next token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "4f972d36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip this if you have trained your own model\n",
    "# If you use my model, you should run this code cell\n",
    "# download and unzip https://gattonweb.uky.edu/faculty/lium/gai/wordLSTM.zip\n",
    "\n",
    "import pickle\n",
    "model.load_state_dict(torch.load(\"files/wordLSTM.pth\"))\n",
    "with open(\"files/word_to_int.p\",\"rb\") as fb:    \n",
    "    word_to_int = pickle.load(fb)      \n",
    "int_to_word={v:k for k,v in word_to_int.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "252eb61d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def sample(model, prompt, length=200):\n",
    "    model.eval()\n",
    "    text = prompt.lower().split(' ')\n",
    "    hc = model.init_hidden(1)\n",
    "    length = length - len(text)\n",
    "    for i in range(0, length):\n",
    "        # if the text length is less than seq_len, use text to predict \n",
    "        if len(text)<=seq_len:\n",
    "            x = torch.tensor([[word_to_int[w] for w in text]])\n",
    "        # otherwise use the last seq_len tokens to predict\n",
    "        else:\n",
    "            x = torch.tensor([[word_to_int[w] for w in text[-seq_len:]]])            \n",
    "        inputs = x.to(device)\n",
    "        output, hc = model(inputs, hc)\n",
    "        logits = output[0][-1]\n",
    "        p = nn.functional.softmax(logits, dim=0).detach().cpu().numpy()\n",
    "        idx = np.random.choice(len(logits), p=p)\n",
    "        text.append(int_to_word[idx])\n",
    "    text=\" \".join(text)\n",
    "    for m in \",.:;?!$()/_&%*@'`\":\n",
    "        text=text.replace(f\" {m}\", f\"{m} \")\n",
    "    text=text.replace('\"  ', '\"')   \n",
    "    text=text.replace(\"'  \", \"'\")  \n",
    "    text=text.replace('\" ', '\"')   \n",
    "    text=text.replace(\"' \", \"'\")     \n",
    "    return text  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1bbd71d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "anna and the prince did not forget what he had not spoken.  when the softening barrier was not so long as he had talked to his brother,  all the hopelessness of the impression.  \"official tail,  a man who had tried him,  though he had been able to get across his charge and locked close,  and the light round the snow was in the light of the altar villa.  the article in law levin was first more precious than it was to him so that if it was most easy as it would be as the same.  this was now perfectly interested.  when he had got up close out into the sledge,  but it was locked in the light window with their one grass,  and in the band of the leaves of his projects,  and all the same stupid woman,  and really,  and i swung his arms round that thinking of bed.  a little box with the two boys were with the point of a gleam of filling the boy,  noiselessly signed the bottom of his mouth,  and answering them took the red\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "print(sample(model, prompt='Anna and the prince'))  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b838002",
   "metadata": {},
   "source": [
    "## 5.2\tTemperature and top-K sampling in text generation"
   ]
  },
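  {
   "cell_type": "markdown",
   "id": "7e1a0c06",
   "metadata": {},
   "source": [
    "Before applying temperature and top-K sampling to the trained model, it may help to see their effect on a small set of made-up logits. The sketch below uses only the Python standard library. The helpers `softmax` and `top_k_probs` and the values in `toy_logits` are assumptions for illustration and are separate from the `generate()` function defined in this section.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def softmax(logits, temperature=1.0):\n",
    "    # Divide the logits by the temperature, then normalize\n",
    "    scaled = [z / temperature for z in logits]\n",
    "    m = max(scaled)  # subtract the max for numerical stability\n",
    "    exps = [math.exp(z - m) for z in scaled]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "def top_k_probs(probs, k):\n",
    "    # Keep the k largest probabilities and renormalize them to sum to 1\n",
    "    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]\n",
    "    mass = sum(probs[i] for i in top)\n",
    "    return {i: probs[i] / mass for i in top}\n",
    "\n",
    "toy_logits = [2.0, 1.0, 0.5, -1.0]\n",
    "for t in (0.5, 1.0, 2.0):\n",
    "    print(t, [round(p, 3) for p in softmax(toy_logits, t)])\n",
    "print(top_k_probs(softmax(toy_logits), k=2))\n",
    "```\n",
    "\n",
    "A temperature below 1 concentrates probability on the largest logit, a temperature above 1 spreads it more evenly, and top-K removes all but the K most likely tokens before sampling."
   ]
  },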
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "9d74ad57",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(model, prompt, top_k=None, \n",
    "             length=200, temperature=1):\n",
    "    model.eval()\n",
    "    text = prompt.lower().split(' ')\n",
    "    hc = model.init_hidden(1)\n",
    "    length = length - len(text)    \n",
    "    for i in range(0, length):\n",
    "        # if the text length is less than seq_len, use text to predict \n",
    "        if len(text)<=seq_len:\n",
    "            x = torch.tensor([[word_to_int[w] for w in text]])\n",
    "        # otherwise use the last seq_len tokens to predict\n",
    "        else:\n",
    "            x = torch.tensor([[word_to_int[w] for w in text[-seq_len:]]])    \n",
    "        inputs = x.to(device)\n",
    "        output, hc = model(inputs, hc)\n",
    "        logits = output[0][-1]\n",
    "        # scale the logits with the temperature \n",
    "        logits = logits/temperature\n",
    "        p = nn.functional.softmax(logits, dim=0).detach().cpu()    \n",
    "        if top_k is None:\n",
    "            idx = np.random.choice(len(logits), p=p.numpy())\n",
    "        # top-K sampling\n",
    "        else:\n",
    "            ps, tops = p.topk(top_k)\n",
    "            ps=ps/ps.sum()\n",
    "            idx = np.random.choice(tops, p=ps.numpy())          \n",
    "        text.append(int_to_word[idx])\n",
    "    text=\" \".join(text)\n",
    "    for m in \",.:;?!$()/_&%*@'`\":\n",
    "        text=text.replace(f\" {m}\", f\"{m} \")\n",
    "    text=text.replace('\"  ', '\"')   \n",
    "    text=text.replace(\"'  \", \"'\")  \n",
    "    text=text.replace('\" ', '\"')   \n",
    "    text=text.replace(\"' \", \"'\")     \n",
    "    return text   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c8ef9105",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i'm not going to see you\n",
      "i'm not going to see those\n",
      "i'm not going to see me\n",
      "i'm not going to see you\n",
      "i'm not going to see her\n",
      "i'm not going to see her\n",
      "i'm not going to see the\n",
      "i'm not going to see my\n",
      "i'm not going to see you\n",
      "i'm not going to see me\n"
     ]
    }
   ],
   "source": [
    "# next token using default setting\n",
    "prompt=\"I ' m not going to see\"\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "for _ in range(10):\n",
    "    print(generate(model, prompt, top_k=None, \n",
    "         length=len(prompt.split(\" \"))+1, temperature=1)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "99cabd0a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i'm not going to see you\n",
      "i'm not going to see the\n",
      "i'm not going to see her\n",
      "i'm not going to see you\n",
      "i'm not going to see you\n",
      "i'm not going to see you\n",
      "i'm not going to see you\n",
      "i'm not going to see her\n",
      "i'm not going to see you\n",
      "i'm not going to see her\n"
     ]
    }
   ],
   "source": [
    "# next token using conservative predictions\n",
    "prompt=\"I ' m not going to see\"\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "for _ in range(10):\n",
    "    print(generate(model, prompt, top_k=3, \n",
    "         length=len(prompt.split(\" \"))+1, temperature=0.5)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "8a79a77f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "anna and the prince had no milk.  but,  \"answered levin,  and he stopped.  \"i've been skating to look at you all the harrows,  and i'm glad. . .  \"\"no,  i'm going to the country.  \"\"no,  it's not a nice fellow.  \"\"yes,  sir.  \"\"well,  what do you think about it?  \"\"why,  what's the matter?  \"\"yes,  yes,  \"answered levin,  smiling,  and he went into the hall.  \"yes,  i'll come for him and go away,  \"he said,  looking at the crumpled front of his shirt.  \"i have not come to see him,  \"she said,  and she went out.  \"i'm very glad,  \"she said,  with a slight bow to the ambassador's hand.  \"i'll go to the door.  \"she looked at her watch,  and she did not know what to say\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "print(generate(model, prompt='Anna and the prince',\n",
    "               top_k=3,\n",
    "               temperature=0.5)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "49908224-0e83-4774-9379-faae55da7c1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# exercise 8.4\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "print(generate(model, prompt='Anna and the nurse',\n",
    "               top_k=10,\n",
    "               temperature=0.6)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "53883f66",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i'm not going to see them\n",
      "i'm not going to see scarlatina\n",
      "i'm not going to see behind\n",
      "i'm not going to see us\n",
      "i'm not going to see it\n",
      "i'm not going to see it\n",
      "i'm not going to see a\n",
      "i'm not going to see misery\n",
      "i'm not going to see another\n",
      "i'm not going to see seryozha\n"
     ]
    }
   ],
   "source": [
    "# next token using creative predictions\n",
    "prompt=\"I ' m not going to see\"\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "for _ in range(10):\n",
    "    print(generate(model, prompt, top_k=None, \n",
    "         length=len(prompt.split(\" \"))+1, temperature=2)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "18686bc4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "anna and the prince took sheaves covered suddenly people.  \"pyotr marya borissovna,  propped mihail though her son will seen how much evening her husband;  if tomorrow she liked great time too.  \"adopted heavens details for it women from this terrible,  admitting this touching all everything ill with flirtation shame consolation altogether:  ivan only all the circle with her honorable carriage in its house dress,  beethoven ashamed had the conversations raised mihailov stay of close i taste work?  \"on new farming show ivan nothing.  hat yesterday if interested understand every hundred of two with six thousand roubles according to women living over a thousand:  snetkov possibly try disagreeable schools with stake old glory mysterious one have people some moral conclusion,  got down and then their wreath.  darya alexandrovna thought inwardly peaceful with varenka out of the listen from and understand presented she was impossible anguish.  simply satisfied with staying after presence came where he pushed up his hand as marya her pretty hands into their quarters.  waltz was about the rider gathered;  sviazhsky further alone have an hand paused riding towards an exquisite\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "print(generate(model, prompt='Anna and the prince',\n",
    "               top_k=None,\n",
    "               temperature=2)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "541d7525",
   "metadata": {},
   "outputs": [],
   "source": [
    "# exercise 8.5\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "print(generate(model, prompt='Anna and the nurse',\n",
    "               top_k=10000,\n",
    "               temperature=2)) "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
